# Copyright 2019 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Core model definition of YAMNet."""

import csv

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, layers

from . import features as features_lib


def _batch_norm(name, params):
    """Return a callable that applies a named BatchNormalization layer.

    Args:
      name: Name assigned to the Keras BatchNormalization layer.
      params: Hyperparameter object; its `batchnorm_center`, `batchnorm_scale`
        and `batchnorm_epsilon` attributes configure the layer.

    Returns:
      A one-argument function that builds the layer and applies it to its input.
    """

    def _apply(tensor):
        # The layer is constructed at call time, so params is read lazily.
        bn = layers.BatchNormalization(
            name=name,
            center=params.batchnorm_center,
            scale=params.batchnorm_scale,
            epsilon=params.batchnorm_epsilon,
        )
        return bn(tensor)

    return _apply


def _conv(name, kernel, stride, filters, params):
    """Return a callable applying Conv2D -> batch norm -> ReLU.

    Args:
      name: Prefix for the layer names ("<name>/conv", "<name>/conv/bn",
        "<name>/relu").
      kernel: Kernel size passed to Conv2D.
      stride: Stride passed to Conv2D.
      filters: Number of output filters of the convolution.
      params: Hyperparameter object; `conv_padding` is read here and the
        batch-norm settings are read by `_batch_norm`.

    Returns:
      A one-argument function mapping an input tensor to the block's output.
    """

    def _apply(tensor):
        # No bias and no activation on the convolution: batch norm and ReLU
        # are applied as separate layers right after it.
        conv = layers.Conv2D(
            name=f"{name}/conv",
            filters=filters,
            kernel_size=kernel,
            strides=stride,
            padding=params.conv_padding,
            use_bias=False,
            activation=None,
        )
        net = conv(tensor)
        net = _batch_norm(f"{name}/conv/bn", params)(net)
        relu = layers.ReLU(name=f"{name}/relu")
        return relu(net)

    return _apply


def _separable_conv(name, kernel, stride, filters, params):
    """Return a callable applying a depthwise-then-pointwise convolution block.

    The block is DepthwiseConv2D -> batch norm -> ReLU followed by a 1x1
    Conv2D -> batch norm -> ReLU.

    Args:
      name: Prefix for the layer names ("<name>/depthwise_conv..." and
        "<name>/pointwise_conv...").
      kernel: Kernel size of the depthwise convolution.
      stride: Stride of the depthwise convolution (the pointwise one uses 1).
      filters: Number of output filters of the pointwise convolution.
      params: Hyperparameter object; `conv_padding` is read here and the
        batch-norm settings are read by `_batch_norm`.

    Returns:
      A one-argument function mapping an input tensor to the block's output.
    """

    def _apply(tensor):
        # Settings common to both convolutions in the block.
        shared = dict(
            padding=params.conv_padding,
            use_bias=False,
            activation=None,
        )
        depthwise_name = f"{name}/depthwise_conv"
        pointwise_name = f"{name}/pointwise_conv"

        # Depthwise stage: per-channel spatial filtering.
        net = layers.DepthwiseConv2D(
            name=depthwise_name,
            kernel_size=kernel,
            strides=stride,
            depth_multiplier=1,
            **shared,
        )(tensor)
        net = _batch_norm(f"{depthwise_name}/bn", params)(net)
        net = layers.ReLU(name=f"{depthwise_name}/relu")(net)

        # Pointwise stage: 1x1 convolution producing `filters` channels.
        net = layers.Conv2D(
            name=pointwise_name,
            filters=filters,
            kernel_size=(1, 1),
            strides=1,
            **shared,
        )(net)
        net = _batch_norm(f"{pointwise_name}/bn", params)(net)
        return layers.ReLU(name=f"{pointwise_name}/relu")(net)

    return _apply


# Layer stack consumed in order by yamnet(): one standard convolution followed
# by thirteen separable convolutions, all with 3x3 kernels.
# Each entry is (layer_function, kernel, stride, num_filters).
_YAMNET_LAYER_DEFS = [(_conv, [3, 3], 2, 32)] + [
    (_separable_conv, [3, 3], stride, num_filters)
    for stride, num_filters in (
        (1, 64),
        (2, 128),
        (1, 128),
        (2, 256),
        (1, 256),
        (2, 512),
        (1, 512),
        (1, 512),
        (1, 512),
        (1, 512),
        (1, 512),
        (2, 1024),
        (1, 1024),
    )
]


def yamnet(features, params):
    """Define the core YAMNet model in Keras.

    Args:
      features: Input tensor of patches; it is reshaped to
        (patch_frames, patch_bands, 1) before the convolutional stack.
      params: Hyperparameter object providing `patch_frames`, `patch_bands`,
        `num_classes`, `classifier_activation` and the settings read by the
        layer functions in `_YAMNET_LAYER_DEFS`.

    Returns:
      A (predictions, embeddings) tuple: the activated classifier outputs and
      the globally average-pooled features that feed the classifier.
    """
    patch_shape = (params.patch_frames, params.patch_bands)
    # Append a single-channel axis so the 2-D convolutions can be applied.
    net = layers.Reshape(patch_shape + (1,), input_shape=patch_shape)(features)

    # Layers are named "layer1", "layer2", ... in table order.
    for index, (layer_fun, kernel, stride, filters) in enumerate(
        _YAMNET_LAYER_DEFS, start=1
    ):
        build_layer = layer_fun(f"layer{index}", kernel, stride, filters, params)
        net = build_layer(net)

    embeddings = layers.GlobalAveragePooling2D()(net)
    logits = layers.Dense(units=params.num_classes, use_bias=True)(embeddings)
    predictions = layers.Activation(activation=params.classifier_activation)(logits)
    return predictions, embeddings


def yamnet_frames_model(params):
    """Build the YAMNet model that maps a waveform to per-frame class scores.

    The waveform input is padded, converted to a log mel spectrogram and
    patches by `features_lib`, and the patches are fed through `yamnet`.

    Args:
      params: An instance of Params containing hyperparameters.

    Returns:
      A Keras model named "yamnet_frames" accepting (num_samples,) waveform
      input and emitting, in this order:
      - predictions: (num_patches, num_classes) matrix of class scores per time frame
      - embeddings: (num_patches, embedding size) matrix of embeddings per time frame
      - log_mel_spectrogram: (num_spectrogram_frames, num_mel_bins) spectrogram feature matrix
    """
    waveform = layers.Input(batch_shape=(None,), dtype=tf.float32)

    # Feature extraction: pad, then compute the spectrogram and its patches.
    padded = features_lib.pad_waveform(waveform, params)
    spectrogram_and_patches = features_lib.waveform_to_log_mel_spectrogram_patches(
        padded, params
    )
    log_mel_spectrogram, patches = spectrogram_and_patches

    predictions, embeddings = yamnet(patches, params)
    return Model(
        name="yamnet_frames",
        inputs=waveform,
        outputs=[predictions, embeddings, log_mel_spectrogram],
    )


def class_names(class_map_csv):
    """Read the class name definition file and return the display names.

    Args:
      class_map_csv: Path to the class map CSV file, or a tensor whose
        `.numpy()` value is such a path. The file is expected to have one
        header row followed by three-column rows whose third column is the
        display name.

    Returns:
      A 1-D NumPy array of display-name strings in file order (empty when the
      file contains no data rows).

    Raises:
      ValueError: If a non-blank data row does not have exactly three columns.
    """
    # Plain paths need no unwrapping, so TensorFlow is only consulted for
    # other input types.
    if not isinstance(class_map_csv, (str, bytes)) and tf.is_tensor(class_map_csv):
        class_map_csv = class_map_csv.numpy()
    # Explicit UTF-8 makes decoding independent of the platform locale, and
    # newline="" is what the csv module requires for correct handling of
    # quoted fields and line endings.
    with open(class_map_csv, encoding="utf-8", newline="") as csv_file:
        reader = csv.reader(csv_file)
        next(reader, None)  # Skip header; tolerate a completely empty file.
        # csv.reader yields [] for blank lines; drop them so a stray empty
        # line does not break the three-column unpacking below.
        rows = (row for row in reader if row)
        return np.array(
            [display_name for (_, _, display_name) in rows], dtype=str
        )
